## Script setup: clear the workspace and load project-wide settings.
## NOTE(review): data paths (datDir, resultsDir) and plotting constants are
## presumably defined in param_general.R -- not visible here.
rm(list = ls())
source('param_general.R')

## Demand-model coefficients (coefDistList / coefSwitchList come from the
## sourced Stata results file).
source('./SwitchCosts/stata_results_switch.R')
## Paper Table 2, Age Restriction vs. Honore
## Position 1 = main ("Fixed Effect"), 2 = naive ("Lagged Dep Var"),
## 3 = standard logit: naive distance coefficient with no switching cost.
DISTANCE <- c(coefDistList$main, coefDistList$naive, coefDistList$naive)
SWITCHING.COST <- c(coefSwitchList$main, coefSwitchList$naive, 0)

## Admission-level births, keeping only the columns used below.
dataPat <- readRDS(paste0(datDir, "ld_women_process_all.rds"))
dataPat <- copy(dataPat[, list(maskssn, hosp_id, pat_zip, age, yearDat)])
setnames(dataPat, "yearDat", "year")
setkey(dataPat, maskssn, year)

## Admissions per woman and the ordinal of each birth (rows are keyed by
## maskssn, year, so .I runs in birth order within a woman).
dataPat[, ':='(mult_admit = .N, birthNum = .I - min(.I) + 1L), by = list(maskssn)]

## For women with several admissions: the previous admission's hospital, zip
## and year (NA on her first admission, and NA for single-admission women).
dataPat[mult_admit > 1,
        ':='(hosp_id_prev = c(NA, head(hosp_id, -1)),
             pat_zip_prev = c(NA, head(pat_zip, -1)),
             year_prev = c(NA, head(year, -1))),
        by = list(maskssn)]
dataPat[, years_since_birth := year - year_prev]

## 2014 share of births that are a woman's first observed birth.
PROB_NEW_MOTHER <- dataPat[year == 2014 & birthNum == 1, .N] / dataPat[year == 2014, .N]

## 2014 distribution of years since the previous birth among repeat births:
## a table indexed by years_since_birth whose entries sum to 1.
PROB_OLD_MOTHER <- prop.table(table(
  dataPat[year == 2014 & birthNum > 1 & years_since_birth > 0, years_since_birth]
))

source("./SwitchCosts/create_delta_conlaw.R")

## Births to patients living in the PSA zip codes. zipPSA, hospitalCrosswalk
## and hospitalsInMarket are presumably defined by create_delta_conlaw.R.
dataPatPSA <- dataPat[pat_zip %in% zipPSA]

## Initial cohort: each PSA woman's most recent observed birth
## (birthNum == mult_admit), with the hospital id stored as character.
dataCohortPSA <- dataPatPSA[birthNum == mult_admit,
                            list(maskssn, hosp_id_alt = as.character(hosp_id),
                                 pat_zip, age, year, birthNum)]

## Inner join on the crosswalk (unmatched hospitals drop out), then recode
## hospitals outside the modelled market to the outside option "0".
dataCohortPSA <- merge(dataCohortPSA, hospitalCrosswalk, by = c("hosp_id_alt"))
dataCohortPSA[!(hospital %in% hospitalsInMarket), hosp_id_alt := "0"]
dataCohortPSA[, hospital := NULL]

## These columns describe the woman's last birth, so label them "previous".
setnames(dataCohortPSA,
         old = c("hosp_id_alt", "year", "age"),
         new = c("hosp_id_prev", "year_prev", "age_prev"))

## Do 2015
## Cohort demand based on 2013 actual births per zip, projected forward with
## the zip-group population adjustments (as in create_delta_conlaw.R).
## nZipNew is created as double on purpose. A data.table sub-assignment (`:=`
## with `i`) coerces the RHS to the column's existing type, so scaling the
## integer .N column would truncate (not round) the adjusted counts.
demandTotal <- dataPatPSA[year == 2013, list(nZipNew = as.numeric(.N)), by = c("pat_zip")]
demandTotal[pat_zip %in% ZIP_CLOSER, nZipNew := nZipNew * POP_ADJ_CLOSER]
demandTotal[pat_zip %in% ZIP_CONSTANT, nZipNew := nZipNew * POP_ADJ_CONSTANT]
demandTotal[pat_zip %in% ZIP_FARTHER, nZipNew := nZipNew * POP_ADJ_FARTHER]
## Expected number of first-time mothers per zip (whole births).
demandTotal[, newMother := round(PROB_NEW_MOTHER * round(nZipNew))]

## New Mothers
## One synthetic first-time mother per unit of projected newMother demand in
## each zip, identified as "<zip>-<k>" for k = 1..newMother.
## seq_len() yields no rows for a zip with newMother == 0. The former
## `1:newMother` loop produced c(1, 0) there and invented two mothers.
newMotherCohort <- demandTotal[, list(admission = seq_len(newMother)), by = "pat_zip"]
newMotherCohort[, maskssn := paste0(as.character(pat_zip), "-", as.character(admission))]
newMotherCohort <- newMotherCohort[, list(maskssn, pat_zip)]
## Old Mothers
## Probability of a repeat birth k years after the previous one, taken from
## the 2014 table. Years come from the table's own names and probabilities
## via as.numeric(). unique() would drop tied probabilities and misalign (or
## break) the pairing with years_since_birth.
prevMotherProb <- data.frame(years_since_birth = as.integer(names(PROB_OLD_MOTHER)),
                             prob = as.numeric(PROB_OLD_MOTHER))

## Cross every zip with every years-since-birth value (by = NULL is a
## Cartesian product) and size the repeat-birth demand for each cell.
demandOld <- as.data.table(merge(as.data.frame(demandTotal), prevMotherProb, by = NULL))
demandOld[, demand := round(prob * (1 - PROB_NEW_MOTHER) * round(nZipNew))]


## One copy of the PSA market choice data per demand model; the element order
## matches DISTANCE / SWITCHING.COST (1 = main, 2 = naive, 3 = no switching cost).
## dataPSAMkt is presumably built by create_delta_conlaw.R -- not visible here.
dataPSAMktList = list(all = copy(dataPSAMkt), naive = copy(dataPSAMkt), noSC = copy(dataPSAMkt))


for(xi in 1:3) {
  ## NOTE(review): xiHat uses the model-specific DISTANCE but is discarded by the
  ## column selection on the next line, so all three models keep the same
  ## deltaHat. Confirm whether model-specific mean utilities were intended.
  dataPSAMktList[[xi]][,xiHat := deltaHat - DISTANCE[xi] * (time_current)]
  ## Keep only the columns the simulation uses.
  dataPSAMktList[[xi]] = dataPSAMktList[[xi]][,list(hosp_id_alt,pat_zip,deltaHat)]
  ## Add Outside Option
  ## (hospital id "0" with deltaHat 0, one row per distinct zip in dataPSAMkt)
  dataPSAMktList[[xi]] = rbind(dataPSAMktList[[xi]],as.data.table(list(hosp_id_alt = "0", deltaHat = 0, pat_zip = dataPSAMkt[,unique(pat_zip)])))
}


## Simulate which women give birth in each year YEAR.START..YEAR.END.
##
## dataCohortInitial: existing mothers with columns maskssn, hosp_id_prev,
##   pat_zip, age_prev, year_prev, birthNum (e.g. the PSA last-birth cohort).
## dataCohortNew: synthetic first-time mothers (maskssn, pat_zip). The same set
##   is re-added every year with maskssn prefixed by "<YEAR>-".
## demandOld: repeat-birth demand per (pat_zip, years_since_birth) in `demand`.
## Returns a data.table with one row per simulated birth and a `year` column.
## Draws random numbers (sample.int), so set the seed before calling.
createInitialCohort <- function(dataCohortInitial, dataCohortNew, demandOld, YEAR.START, YEAR.END) {
  allBirthCohorts <- NULL

  ## Initialize Old Cohort: years since each woman's last birth, measured at the
  ## first simulated year (previously hard-coded as 2015).
  cohort <- copy(dataCohortInitial)
  cohort[, years_since_birth := YEAR.START - year_prev]

  for (YEAR in YEAR.START:YEAR.END) {
    ## Attach this year's demand. The inner join drops women in
    ## (zip, years_since_birth) cells with no demand row.
    cohort <- merge(cohort, demandOld[, list(pat_zip, years_since_birth, demand)],
                    by = c("pat_zip", "years_since_birth"))
    ## Draw `demand` women without replacement in each cell. Indexing through
    ## sample.int() follows the same RNG path as sample(x, size) for multi-element
    ## cells, avoids sample()'s sample(1:x) behaviour for a single numeric id,
    ## and caps the draw at the cell size instead of erroring.
    pregnant_women <- cohort[, maskssn[sample.int(.N, min(demand[1L], .N))],
                             by = c("pat_zip", "years_since_birth")]$V1

    cohort[, birth_this_year := (maskssn %in% pregnant_women)]

    ## Next Cohort: this year's first-time mothers all give birth now.
    nextCohortNew <- copy(dataCohortNew)[, ':='(year_prev = YEAR, birthNum = 1, years_since_birth = 1,
                                            hosp_id_prev = -1, birth_this_year = TRUE,
                                            maskssn = paste0(YEAR, "-", maskssn))]
    nextCohortOld <- copy(cohort)
    nextCohortOld[, ":="(demand = NULL, age_prev = NULL)]
    nextCohortOld[birth_this_year == TRUE, birthNum := birthNum + 1]

    ## Births recorded for this year: repeat mothers who gave birth plus all
    ## new mothers (who have no previous birth year).
    birthCohort <- rbind(copy(nextCohortOld)[birth_this_year == TRUE, ],
                         copy(nextCohortNew)[, ':='(year_prev = NA, years_since_birth = 0)])
    birthCohort[, year := YEAR]

    ## Advance the clock: mothers who just gave birth restart at 1 year.
    nextCohortOld[birth_this_year == TRUE, ':='(year_prev = YEAR, years_since_birth = 1)]
    nextCohortOld[birth_this_year == FALSE, ':='(years_since_birth = years_since_birth + 1)]

    ## Women more than 8 years past their last birth leave the cohort.
    nextCohortOld <- nextCohortOld[years_since_birth <= 8, ]
    nextCohort <- rbind(nextCohortNew, nextCohortOld)
    allBirthCohorts <- rbind(allBirthCohorts, birthCohort)
    cohort <- copy(nextCohort)
  }
  allBirthCohorts
}

## Simulate hospital choices year by year for one demand model and summarise
## outcomes at EXCLUDED.HOSP under two scenarios:
##   resultsList[[1]] (named "new"): EXCLUDED.HOSP removed from the choice set
##     (utility shifted by -1e10), i.e. the exclusion counterfactual;
##   resultsList[[2]] (named "old"): all hospitals available.
## The element names are historical and do not describe the contents; callers
## index by position. resultsList[[3]] ("exclusion") is always an empty list.
##
## data: births x choice set with columns maskssn, birthNum, year, pat_zip,
##   hosp_id_alt, deltaHat, chosenPrev and rand (one uniform per birth, as
##   built by doSimulationWrap).
## SWITCHING.COST: utility bonus for returning to the previous hospital.
## EXCLUDED.HOSP: hospital id removed in scenario 1 and summarised in both
##   (defaults to the previously hard-coded "100167").
## ITER and DISTANCE are accepted for call compatibility but not used.
doSimulation <- function(ITER, data, YEAR.START, YEAR.END, DISTANCE, SWITCHING.COST,
                         EXCLUDED.HOSP = "100167") {
  resultsList <- list(new = list(), old = list(), exclusion = list())
  for (xj in 1:2) {
    dataSimul <- copy(data)
    for (YEAR in YEAR.START:YEAR.END) {
      ## Deterministic utility: mean utility plus the switching bonus.
      if (xj == 1) {
        dataSimul[year == YEAR, detUtil := deltaHat + SWITCHING.COST * (chosenPrev == TRUE) + (-1e10) * (hosp_id_alt == EXCLUDED.HOSP)]
      }
      if (xj == 2) {
        dataSimul[year == YEAR, detUtil := deltaHat + SWITCHING.COST * (chosenPrev == TRUE)]
      }
      ## Outside option has no utility, no switching benefit
      dataSimul[year == YEAR & hosp_id_alt == "0", detUtil := 0]
      ## Logit choice probabilities within each birth (maskssn, birthNum).
      dataSimul[year == YEAR, hospShare := exp(detUtil)/sum(exp(detUtil)), by = c("maskssn","birthNum")]
      ## Inverse-CDF draw: the alternative whose cumulative-share interval
      ## contains rand is the observed choice.
      dataSimul[year == YEAR, hospShareCumSum := cumsum(hospShare), by = c("maskssn","birthNum")]
      dataSimul[year == YEAR, obs_choice := (rand <= hospShareCumSum & rand > (hospShareCumSum - hospShare))]
      ## Per-birth log of the summed exponentiated utilities.
      dataSimul[year == YEAR, utility := log(0+sum(exp(detUtil))), by = c("maskssn","birthNum")]
      setkey(dataSimul, maskssn, hosp_id_alt, year)

      ## Carry each woman's choice at every hospital forward to her next
      ## birth row, leaving the outside option's chosenPrev untouched.
      dataSimul[, chosenPrevNew := c(NA, head(obs_choice, -1)), by = list(maskssn, hosp_id_alt)]
      dataSimul[!(is.na(chosenPrevNew) | hosp_id_alt == "0"), chosenPrev := chosenPrevNew]
    }
    ## Zip-by-year summaries at EXCLUDED.HOSP; "Old" columns restrict to
    ## repeat births (birthNum > 1).
    resultsList[[xj]] <- dataSimul[hosp_id_alt == EXCLUDED.HOSP, list(util = sum(utility),
                                                                     utilOld = sum(utility * (birthNum >1)),
                                                                     share = sum(obs_choice)/.N,
                                                                     shareOld = sum(obs_choice *(birthNum >1))/sum(birthNum >1),
                                                                     N = .N,
                                                                     NOld = sum(birthNum >1)), by = c("pat_zip","year")]
  }
  resultsList
}

## Run one Monte Carlo replication (index ITER) for each of the three demand
## models. Before each model the RNG is reseeded with 123457 + ITER, so the
## simulated birth cohorts are drawn identically for every model.
## Returns an unnamed list of three; element m holds doSimulation()'s first two
## results for model m in slots "new" and "old", and "exclusion" stays empty.
doSimulationWrap <- function(ITER, dataCohortInitial, dataCohortNew, demandOld, dataPSAMktList,
                             DISTANCE, SWITCHING.COST, YEAR.START, YEAR.END) {
  emptySlot <- list(new = list(), old = list(), exclusion = list())
  out <- list(emptySlot, emptySlot, emptySlot)

  for (model in seq_len(3)) {
    set.seed(123457 + ITER)
    births <- createInitialCohort(dataCohortInitial, dataCohortNew, demandOld, YEAR.START, YEAR.END)

    ## Expand every birth to the full choice set of its zip.
    choiceData <- copy(merge(births, dataPSAMktList[[model]],
                             by = c("pat_zip"), allow.cartesian = TRUE))
    ## Previous-hospital flag; never set for the outside option.
    choiceData[, chosenPrev := hosp_id_prev == hosp_id_alt]
    choiceData[hosp_id_alt == "0", chosenPrev := FALSE]
    ## One uniform draw per birth, shared by all of its alternatives.
    choiceData[, rand := runif(1), by = c("maskssn", "birthNum")]

    modelResults <- doSimulation(ITER = ITER, data = choiceData,
                                 DISTANCE = DISTANCE[model],
                                 SWITCHING.COST = SWITCHING.COST[model],
                                 YEAR.START = YEAR.START, YEAR.END = YEAR.END)
    for (scenario in 1:2) {
      out[[model]][[scenario]] <- modelResults[[scenario]]
    }
  }
  out
}


## Simulation horizon (inclusive) and number of Monte Carlo replications.
YEAR.START <- 2015
YEAR.END <- 2025
ITER.LIMIT <- 100

## Run the replications in parallel (mclapply is presumably attached via
## param_general.R -- confirm); replication x is seeded from its index.
time.start <- proc.time()
simul <- mclapply(
  seq_len(ITER.LIMIT),
  function(x) {
    doSimulationWrap(ITER = x, dataCohortInitial = dataCohortPSA,
                     dataCohortNew = newMotherCohort, demandOld = demandOld,
                     dataPSAMktList = dataPSAMktList, DISTANCE = DISTANCE,
                     SWITCHING.COST = SWITCHING.COST,
                     YEAR.START = YEAR.START, YEAR.END = YEAR.END)
  },
  mc.cores = 50
)
time.finish <- proc.time()
## Elapsed time (auto-printed when the script runs at top level).
time.finish - time.start

saveRDS(simul, paste0(resultsDir, "/simulList.rds"))


## Stack all replications into one long table tagged by scenario (CF) and model
## (type). Scenario slot 1 is the exclusion counterfactual and slot 2 the
## original market; model slots follow DISTANCE / SWITCHING.COST. Rows are
## ordered scenario-major, then model.
simulResults <- local({
  modelLabels <- c("Fixed Effect", "Lagged Dep Var", "Standard Logit")
  scenarioLabels <- c("exclusion", "original")
  pieces <- list()
  for (scenario in 1:2) {
    for (model in 1:3) {
      stacked <- do.call("rbind", lapply(1:ITER.LIMIT, function(x) simul[[x]][[model]][[scenario]]))
      pieces[[length(pieces) + 1L]] <- cbind(stacked, CF = scenarioLabels[scenario],
                                             type = modelLabels[model])
    }
  }
  do.call("rbind", pieces)
})

## Utility attributable to first-time mothers.
simulResults[, utilNew := util - utilOld]
## Percent change of the total of x over rows with CF == CF_VALUE relative to
## the total over rows with CF == "original":
##   100 * (sum_CF_VALUE / sum_original - 1)
estChangeCF <- function(x, CF, CF_VALUE) {
  totalScenario <- sum(x * (CF == CF_VALUE))
  totalOriginal <- sum(x * (CF == "original"))
  100 * (totalScenario / totalOriginal - 1)
}
## Percent change relative to the "original" scenario, by year and model type,
## for one group of zips.
## ZIP: "farther", "closer", "constant" (zips in ZIP_FARTHER / ZIP_CLOSER /
##   ZIP_CONSTANT) or "all". Any other value is an error. Previously it fell
##   through to utils::data and failed with an opaque "closure" error.
## Share columns are converted back to counts (share * N) before comparison.
createSimulChange <- function(simulResults, CF_VALUE, ZIP) {
  zipData <- switch(ZIP,
    farther  = simulResults[pat_zip %in% ZIP_FARTHER, ],
    closer   = simulResults[pat_zip %in% ZIP_CLOSER, ],
    constant = simulResults[pat_zip %in% ZIP_CONSTANT, ],
    all      = simulResults,
    stop("Unknown ZIP group '", ZIP, "'; expected one of: farther, closer, constant, all.",
         call. = FALSE)
  )

  zipData[, list(util = estChangeCF(util, CF, CF_VALUE),
                 utilOld = estChangeCF(utilOld, CF, CF_VALUE),
                 utilNew = estChangeCF(utilNew, CF, CF_VALUE),
                 share = estChangeCF(share * N, CF, CF_VALUE),
                 shareOld = estChangeCF(shareOld * NOld, CF, CF_VALUE),
                 shareNew = estChangeCF(share * N - shareOld * NOld, CF, CF_VALUE)),
          by = c("year", "type")]
}



## Plot the per-year percent change for returning mothers (the VAR + "Old"
## column of createSimulChange), one coloured line per model type, and save
## presentation- and paper-sized PNGs under resultsDir/Switch/.
## NOTE(review): another function named createGraph is defined later in this
## script and replaces this one when the file is sourced top to bottom.
createGraph <- function(simulResults, CF_VALUE, ZIP, VAR) {
  changes <- createSimulChange(simulResults, CF_VALUE, ZIP)
  oldVar <- paste0(VAR, "Old")
  newVar <- paste0(VAR, "New")
  changes[, changeOld := get(oldVar)]
  ## Per-type mean change for first-time mothers (kept on the plot data only).
  changes[, changeNew := mean(get(newVar)), by = c("type")]

  graph <- ggplot(data = changes, aes(x = year - 2015, y = changeOld, color = type)) +
    geom_line(size = 2) +
    theme_bw()
  graph <- graph +
    theme(axis.title.x = element_text(face = "bold", size = 14, vjust = .1),
          axis.text.x = element_text(size = 10)) +
    theme(axis.title.y = element_text(face = "bold", size = 14, vjust = .1),
          axis.text.y = element_text(size = 10)) +
    theme(panel.grid.major = element_blank(),
          panel.grid.minor = element_blank()) +
    theme(legend.position = "bottom", legend.box = "vertical") +
    theme(legend.title = element_blank(), legend.text = element_text(size = 12))
  graph <- graph +
    labs(x = "Year", y = "Percent Change") +
    scale_x_continuous(breaks = c(0, 2, 4, 6, 8, 10)) +
    scale_y_continuous(breaks = pretty_breaks(n = 6)) +
    scale_color_manual(values = c("red", "blue", "green"))

  fileStem <- paste0(resultsDir, "/Switch/HospChange_", CF_VALUE, "_", ZIP, "_", VAR)
  ggsave(filename = paste0(fileStem, "_Pres", ".png"), plot = graph, width = widthPres,
         height = heightPres, dpi = dpiPlot, units = unitsPres, type = "cairo")
  ggsave(filename = paste0(fileStem, "_Paper", ".png"), plot = graph, width = widthPaper,
         height = heightPaper, dpi = dpiPlot, units = unitsPaper, type = "cairo")
}


## Plot the per-year percent change for returning mothers (the VAR + "Old"
## column of createSimulChange), distinguishing model types by line type, and
## save presentation- and paper-sized PNGs under resultsDir/Switch/.
## This definition is the one in effect after the script is sourced; it
## replaces the earlier function of the same name.
## Fix: the aesthetic was written `line = type`, which ggplot2 ignores, so the
## three model types were joined into a single zig-zag line. It is now
## `linetype = type`. The colour scale was dropped because no colour is mapped.
createGraph <- function(simulResults, CF_VALUE, ZIP, VAR) {
  simulResultsChange <- createSimulChange(simulResults, CF_VALUE, ZIP)
  simulResultsChange[, changeOld := get(paste0(VAR, "Old"))]
  ## Per-type mean change for first-time mothers (kept on the plot data only).
  simulResultsChange[, changeNew := mean(get(paste0(VAR, "New"))), by = c("type")]

  graph <- ggplot(data = simulResultsChange, aes(x = year - 2015, y = changeOld, linetype = type)) +
    geom_line(size = 2) + theme_bw() +
    theme(axis.title.x = element_text(face = "bold", size = 14, vjust = .1), axis.text.x = element_text(size = 10)) +
    theme(axis.title.y = element_text(face = "bold", size = 14, vjust = .1), axis.text.y = element_text(size = 10)) +
    theme(panel.grid.major = element_blank(),
          panel.grid.minor = element_blank()) +
    theme(legend.position = "bottom", legend.box = "vertical") +
    theme(legend.title = element_blank(), legend.text = element_text(size = 12)) +
    labs(x = "Year", y = "Percent Change") +
    scale_x_continuous(breaks = c(0, 2, 4, 6, 8, 10)) +
    scale_y_continuous(breaks = pretty_breaks(n = 6))

  ggsave(filename = paste0(resultsDir, "/Switch/HospChange_", CF_VALUE, "_", ZIP, "_", VAR, "_Pres", ".png"),
         plot = graph, width = widthPres,
         height = heightPres, dpi = dpiPlot, units = unitsPres, type = "cairo")
  ggsave(filename = paste0(resultsDir, "/Switch/HospChange_", CF_VALUE, "_", ZIP, "_", VAR, "_Paper", ".png"),
         plot = graph, width = widthPaper,
         height = heightPaper, dpi = dpiPlot, units = unitsPaper, type = "cairo")
}


## Greyscale variant: per-year percent change for returning mothers (the
## VAR + "Old" column of createSimulChange), one grey line per model type,
## saved as "_PresBW" / "_PaperBW" PNGs under resultsDir/Switch/.
createGraphBW <- function(simulResults, CF_VALUE, ZIP, VAR) {
  changes <- createSimulChange(simulResults, CF_VALUE, ZIP)
  oldVar <- paste0(VAR, "Old")
  newVar <- paste0(VAR, "New")
  changes[, changeOld := get(oldVar)]
  ## Per-type mean change for first-time mothers (kept on the plot data only).
  changes[, changeNew := mean(get(newVar)), by = c("type")]

  graph <- ggplot(data = changes, aes(x = year - 2015, y = changeOld, color = type)) +
    geom_line(size = 2) +
    theme_bw()
  graph <- graph +
    theme(axis.title.x = element_text(face = "bold", size = 14, vjust = .1),
          axis.text.x = element_text(size = 10)) +
    theme(axis.title.y = element_text(face = "bold", size = 14, vjust = .1),
          axis.text.y = element_text(size = 10)) +
    theme(panel.grid.major = element_blank(),
          panel.grid.minor = element_blank()) +
    theme(legend.position = "bottom", legend.box = "vertical") +
    theme(legend.title = element_blank(), legend.text = element_text(size = 12))
  graph <- graph +
    labs(x = "Year", y = "Percent Change") +
    scale_x_continuous(breaks = c(0, 2, 4, 6, 8, 10)) +
    scale_y_continuous(breaks = pretty_breaks(n = 6)) +
    scale_color_grey()

  fileStem <- paste0(resultsDir, "/Switch/HospChange_", CF_VALUE, "_", ZIP, "_", VAR)
  ggsave(filename = paste0(fileStem, "_PresBW", ".png"), plot = graph, width = widthPres,
         height = heightPres, dpi = dpiPlot, units = unitsPres, type = "cairo")
  ggsave(filename = paste0(fileStem, "_PaperBW", ".png"), plot = graph, width = widthPaper,
         height = heightPaper, dpi = dpiPlot, units = unitsPaper, type = "cairo")
}

## Exclusion counterfactual, all PSA zips, utility: write the colour and
## greyscale figures, then print the per-year changes ordered by model type.
invisible(lapply(list(createGraph, createGraphBW), function(makeGraph) {
  makeGraph(simulResults, "exclusion", "all", "util")
}))
createSimulChange(simulResults, "exclusion", "all")[order(type, year)]




